import os
import sys

from PIL import Image

sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist


def imgshow(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


# 读入dataset，拿到训练集和测试集的数据和标签
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=False, flatten=True, one_hot_label=False)

# 开始测试数据
img = x_train[0]
label = t_train[0]  # 拿到训练集第一个数据的标签
print(label)

print(img.shape)  # 图片已被搞成一维的，将他们二维化
img = img.reshape(28, 28)
print(img.shape)

imgshow(img)
